import gym
import gym3
from gym.spaces import Discrete
from continual_rl.experiments.tasks.image_task import ImageTask
import cv2
import time
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import animation


class ProcgenActionWrapper(gym.ActionWrapper):
    """Expose a ``Discrete(action_size)`` action space and remap it per environment.

    The wrapper replaces the wrapped env's ``action_space`` with
    ``Discrete(action_size)`` and translates each incoming action index into
    the action id forwarded to the wrapped environment.  Environments without
    an entry in ``_ACTION_MAPS`` receive the index unchanged.

    Args:
        env: The environment to wrap.
        env_name: Name used to select the action mapping (e.g. ``"bigfish-v0"``).
        action_size: Number of actions exposed to the agent.
    """

    # Agent-facing index -> native action id, keyed by environment name.
    # Native ids for bigfish: 0 left-down, 1 left, 2 left-up, 3 down, 4 noop,
    # 5 up, 6 right-down, 7 right, 8 right-up.  The table therefore puts noop
    # and the four cardinal moves first, followed by the diagonals.
    _ACTION_MAPS = {
        "bigfish-v0": (4, 5, 3, 1, 7, 2, 0, 8, 6),
    }

    def __init__(self, env, env_name, action_size):
        super().__init__(env)
        self.action_space = Discrete(action_size)
        self.env_name = env_name
        # Ensures the "no mapping" warning is emitted once, not on every step.
        self._warned_unmapped = False

    def action(self, act):
        """Translate an agent action index into the wrapped env's action id.

        Args:
            act: Action index; anything accepted by ``int()``.

        Returns:
            The mapped native action id, or ``act`` itself (as an ``int``) when
            no mapping is defined for ``self.env_name``.

        Raises:
            AssertionError: If ``act`` is outside ``[0, action_space.n)``.
        """
        act = int(act)
        # Raised explicitly rather than via ``assert`` so the check survives
        # ``python -O``; without it a negative index would silently wrap around
        # the mapping table.  AssertionError is kept for backward compatibility.
        if not 0 <= act < self.action_space.n:
            raise AssertionError(
                f"action {act} is outside the valid range [0, {self.action_space.n})"
            )

        mapping = self._ACTION_MAPS.get(self.env_name)
        if mapping is not None:
            return mapping[act]

        if not self._warned_unmapped:
            print("!!!Warning: Action mapping not defined for this environment!!!")
            self._warned_unmapped = True
        return act


def make_procgen(env_name, num_levels=0, start_level=0, distribution_mode="easy", action_size=None):
    """Create a procgen gym environment, optionally with a remapped action space.

    Args:
        env_name: Suffix appended to ``"procgen:procgen-"`` to form the gym id.
        num_levels: Forwarded to ``gym.make`` as ``num_levels``.
        start_level: Forwarded to ``gym.make`` as ``start_level``.
        distribution_mode: Forwarded to ``gym.make`` as ``distribution_mode``.
        action_size: When not ``None``, the environment is wrapped in
            ``ProcgenActionWrapper`` exposing this many discrete actions.

    Returns:
        The created environment, wrapped only if ``action_size`` was given.
    """
    gym_id = f"procgen:procgen-{env_name}"
    base_env = gym.make(
        gym_id,
        num_levels=num_levels,
        start_level=start_level,
        distribution_mode=distribution_mode,
        # render_mode="rgb_array"
    )

    # Without a requested action size the raw environment is handed back as-is.
    if action_size is None:
        return base_env
    return ProcgenActionWrapper(base_env, env_name, action_size)


def get_single_procgen_task(task_id, action_space_id, env_name, num_timesteps, eval_mode=False, action_size=None, **kwargs):
    """Build an ``ImageTask`` backed by a lazily constructed procgen environment.

    Args:
        task_id: Passed through to ``ImageTask``.
        action_space_id: Passed through to ``ImageTask``.
        env_name: Procgen environment name handed to ``make_procgen``.
        num_timesteps: Passed through to ``ImageTask``.
        eval_mode: Passed through to ``ImageTask``.
        action_size: Forwarded to ``make_procgen``.
        **kwargs: Extra keyword arguments forwarded to ``make_procgen``.

    Returns:
        The configured ``ImageTask`` (64x64 colour frames, no frame stacking).
    """

    def build_env():
        # Deferred so the environment is only created when the task asks for it.
        return make_procgen(env_name, action_size=action_size, **kwargs)

    task = ImageTask(
        task_id=task_id,
        action_space_id=action_space_id,
        env_spec=build_env,
        num_timesteps=num_timesteps,
        time_batch_size=1,  # no framestack
        eval_mode=eval_mode,
        image_size=[64, 64],
        grayscale=False,
    )
    return task


def get_key_action():
    """Prompt on stdin for an action and return the normalised text.

    Returns:
        The entered line with surrounding whitespace removed and lower-cased.
        No conversion to an action index is performed here.
    """
    raw_text = input("Enter action: ")
    trimmed = raw_text.strip()
    return trimmed.lower()


def display_frame(frame):
    """Show a single frame with matplotlib, hiding the axes.

    Args:
        frame: Image data accepted by ``plt.imshow``.

    NOTE(review): with a non-interactive matplotlib setup ``plt.show()`` is
    expected to block until the window is closed -- confirm for the backend
    in use.
    """
    plt.imshow(frame)
    plt.axis("off")  # hide ticks and borders so only the image is visible
    plt.show()


def _prompt_for_action(num_actions):
    """Ask until the user enters a valid action index or a quit command.

    Returns:
        An ``int`` in ``[0, num_actions)``, or ``None`` if the user asked to quit.
    """
    while True:
        key = get_key_action()
        if key in ("q", "quit", "exit"):
            return None
        try:
            action = int(key)
        except ValueError:
            print(f"Invalid input {key!r}: enter an integer in [0, {num_actions - 1}] or 'q' to quit.")
            continue
        if 0 <= action < num_actions:
            return action
        print(f"Action {action} is out of range: enter an integer in [0, {num_actions - 1}] or 'q' to quit.")


def main():
    """Interactively step a procgen environment from the keyboard.

    The environment is first stepped 20 times with action 0, then the current
    frame is shown and the user is prompted for an action index on every
    iteration.  Entering ``q``, ``quit`` or ``exit`` (or sending EOF / Ctrl-C)
    ends the session; the environment is always closed on the way out.
    """
    env_name = "bigfish-v0"  # procgen environment to play
    num_actions = 9
    env = make_procgen(env_name, action_size=num_actions)
    try:
        env.reset()

        # Warm-up: advance a few steps with action 0 before handing over control.
        for _ in range(20):
            _, reward, done, _ = env.step(0)
            print(f"Action: 0, Reward: {reward}, Done: {done}")
            if done:
                env.reset()

        while True:
            display_frame(env.render(mode='rgb_array'))

            # The prompt yields text; validate and convert it here instead of
            # letting a bad entry crash inside env.step().
            action = _prompt_for_action(num_actions)
            if action is None:
                break

            _, reward, done, _ = env.step(action)
            print(f"Action: {action}, Reward: {reward}, Done: {done}")
            if done:
                env.reset()

            time.sleep(0.1)  # throttle the loop rate
    except (EOFError, KeyboardInterrupt):
        # Closed stdin or Ctrl-C is treated as a normal request to stop.
        print()
    finally:
        env.close()


# Run the interactive demo only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()